import tensorflow as tf
from tensorflow import keras
from yolo import yolo_body  # 加载网络的预测头

# 构造模型
def MODEL(input_shape, num_anchors, num_classes, summary=False):

    # 接收模型结构
    model = yolo_body(input_shape, num_anchors, num_classes)

    # 是否查看模型结构
    if summary:
        model.summary()
    # 返回模型
    return model

# 验证
if __name__ == '__main__':

    # 构造输入层
    MODEL(input_shape=[416,416,3],
          num_anchors=3,  # 每个网格生成3个先验框
          num_classes=20,  # VOC数据集20个类别 
          summary=True)
